import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import pylab

'''
自动下载 mnist 数据集并解压到 mnist_data 目录下

如果程序下载慢，就上网把数据集下载下来，放到 mnist_data 文件夹下
t10k-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte.gz
train-images-idx3-ubyte.gz
train-labels-idx1-ubyte.gz

代码中的 one_hot=True, 表示将样本标签转化为 one_hot 编码
例： 假设一共 10 类。 0的 one_hot 编码为 1000000000，1的 one_hot 编码为 0100000000，
    2的 one_hot 编码为 0010000000 ...... 依次类推，1所在的位置就代表着几类。
'''
mnist = input_data.read_data_sets('mnist_data/', one_hot=True)
'''
MNIST 数据集中的图片是 28x28 Pixel，所以一幅图就有 1x728 (28x28) 的数据，
如果是黑白图片的话，黑色地方数值为 0，有图案的地方，数值为 0~255 之间数字，代表颜色的深度
如果是彩色图片的话，一个像素会有3个值来表示 RGB 
'''
# 测试数据集
print('input_data : ', mnist.train.images)
print('input_data_shape : ', mnist.train.images.shape)  # input_data_shape :  (55000, 784)
# 从 55000x784 的矩阵中返回第二个图像的向量
im = mnist.train.images[1]
# 重塑取出的向量，reshape 为 28x28 的矩阵
im = im.reshape(-1, 28)
# 将矩阵图像化
pylab.imshow(im)
# 显示图像
pylab.show()


'''
输入数据集是一个 55000x784 的矩阵，所以先创建一个[None, 784] 的占位符 images 和一个 [None, 10]的占位符 labels 
然后使用 feed 机制将图片和标签输入进去
'''
# 定义变量
tf.reset_default_graph()
images = tf.placeholder(tf.float32, [None, 784], name='images')
labels = tf.placeholder(tf.float32, [None, 10], name='labels')
# 定义学习参数
weights = tf.Variable(tf.random_normal([784, 10]))  # weights 的维度是[784, 10]
biases = tf.Variable(tf.zeros([10]))                # biases 的shape是(10, )

# 构建模型
z = tf.matmul(images, weights) + biases
# softmax ：输出的是一个多维向量，不论有多少个分量，其加和都是1，每个向量的分量维度是小于1的值，而这个值可以做概率解释的
pred = tf.nn.softmax(z)

# 定义方向传播结构，优化参数
# 定义损失函数
cost = tf.reduce_mean(-tf.reduce_sum(labels*tf.log(pred), reduction_indices=1))
# 定义学习参数
learning_rate = 0.01
# 使用梯度下降优化参数
train = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

# 定义保存模型的参数
model_path = 'model/mnist_model.ckpt'
saver = tf.train.Saver()

# 训练模型
'''
training_epochs 训练样本迭代次数
batch_size  训练过程中每次取出训练的数据量 
'''
training_epochs = 25
batch_size = 100
display_step = 1

# 启动session
with tf.Session() as sess_save:
    # 初始化参数
    sess_save.run(tf.global_variables_initializer())
    # 启动训练循环
    for epoch in range(training_epochs):
        avg_cost = 0
        total_batch = int(mnist.train.num_examples/batch_size)
        for _ in range(total_batch):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            # 传参优化
            _, cost_result = sess_save.run([train, cost], feed_dict={images: batch_x, labels: batch_y})
            # 计算平均loss值
            avg_cost += cost_result/total_batch
        if (epoch + 1) % display_step == 0:
            print('Epoch: ', '%04d' % (epoch + 1), 'cost: ', '{:.9f}'.format(avg_cost))

    # 保存模型
    saver.save(sess_save, model_path)

    print('Finished!')

    # 测试模型
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(labels, 1))
    # 计算准确率
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print('Accuracy: ', accuracy.eval({images: mnist.test.images, labels: mnist.test.labels}))


# 读取模型
print('下面是读取模型，并输入测试: ')
with tf.Session() as sess_restore:
    # 初始化全局变量
    sess_restore.run(tf.global_variables_initializer())
    # 载入模型
    saver.restore(sess_restore, model_path)

    # 测试模型
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(labels, 1))
    # 计算准确率
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print('Accuracy: ', accuracy.eval({images: mnist.test.images, labels: mnist.test.labels}))

    # 输出预测的labels向量中值为 最大概率值 的下标
    output = tf.argmax(pred, 1)
    # 选取两张图片进行预测
    batch_x, batch_y = mnist.train.next_batch(2)
    output_val, pred_v = sess_restore.run([output, pred], feed_dict={images: batch_x, labels: batch_y})
    print(output_val, pred_v, batch_y)

    im = batch_x[0]
    im = im.reshape(-1, 28)
    pylab.imshow(im)
    pylab.show()

    im = batch_x[1]
    im = im.reshape(-1, 28)
    pylab.imshow(im)
    pylab.show()